{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Neural Style\n",
    "\n",
    "\n",
    "[Neural style transfer](https://arxiv.org/abs/1508.06576) 就是使用神经网络来完成将一张指定图片的风格应用到另外一张图片上的算法。\n",
    "\n",
    "\n",
    "![neural-style](neural-style.png)\n",
    "\n",
    "\n",
    "\n",
    "它的流程如下所示：\n",
    "\n",
    "\n",
    "![neural-style2](neural-style2.jpg)\n",
    "\n",
    "\n",
    "\n",
    "文字部分的描述就是：\n",
    "\n",
    "\n",
    "1、准备好内容图片 $content$ 和样式图片$style$。\n",
    "\n",
    "2、初始化合成图片 $x$ ，然后把 $x$ 、$content$ 、$style$ 都扔进 VGG19提取特征，其中的 $x$ 和 $style$ 的层1, 2, 4来对样式层进行匹配； $x$ 和 $content$ 的层3作为内容层来匹配。\n",
    "\n",
    "3、使用 $style \\_ Loss$ 和 $content \\_ Loss$ 来分别计算样式损失和内容损失，并对输入 $x$ 求导，导数记作 $g$\n",
    "\n",
    "4、更新输入 $x$ 的值 $x = x - lr * g$\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "说明部分图片来源：https://zh.gluon.ai/chapter_computer-vision/neural-style.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 结果展示\n",
    "\n",
    "\n",
    "\n",
    "![result](./result.jpg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.autograd import Variable\n",
    "import torch.nn as nn\n",
    "from torchvision import models, transforms\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "import torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Config(object):\n",
    "    IMAGENET_MEAN = [0.485, 0.456, 0.406]  # IMAGENET中的图片归一化\n",
    "    IMAGENET_STD = [0.229, 0.224,  0.225]  # IMAGENET中的图片归一化\n",
    "    imsize = 512  # 统一 content image 和 style image 的 size\n",
    "    style_image_path = '../images/candy.jpg' # 样式图片的路径\n",
    "    content_image_path = '../images/hoovertowernight.jpg' # 内容图片的路径\n",
    "    DOWNLOAD = True  # 是否下载预训练模型\n",
    "    lr = 0.05  # 学习速率\n",
    "    epoches = 5000  # 训练epoch\n",
    "    show_epoch = 5  # 显示损失的epoch\n",
    "    sample_epoch = 500  # 采样的epoch\n",
    "    c_weight = 10  # content_loss 的权重\n",
    "    s_weight = 1500  # style_loss的权重\n",
    "    use_cuda = torch.cuda.is_available()\n",
    "    dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n",
    "    \n",
    "# 参数配置实例化\n",
    "config = Config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# 定义预处理函数\n",
    "# 用于把原始图片进行归一化，转换成卷积神经网络可以接收的输入格式\n",
    "def preprocess(image_path, trasform=None):\n",
    "    image = Image.open(image_path)\n",
    "    image = trasform(image)\n",
    "    image = image.unsqueeze(0)\n",
    "    return image.type(config.dtype)\n",
    "\n",
    "transformer = transforms.Compose([\n",
    "        transforms.Scale(config.imsize),\n",
    "        transforms.ToTensor()\n",
    "        transforms.Normalize(mean = config.IMAGENET_MEAN,std = config.IMAGENET_STD)\n",
    "    ])\n",
    "\n",
    "pltshow = transforms.ToPILImage()\n",
    "\n",
    "# 图片展示\n",
    "def imshow(tensor, title=None):\n",
    "    image = tensor.clone().cpu()\n",
    "    image = image.view(3, config.imsize, config.imsize)\n",
    "    image = pltshow(image)\n",
    "    plt.imshow(image)\n",
    "    if title is not None:\n",
    "        plt.title(title)\n",
    "    plt.pause(0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 定义模型\n",
    "class VGG_extract(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(VGG_extract, self).__init__()\n",
    "        self.style_layers = [0, 5, 10, 19, 28]\n",
    "        self.content_layers = [25]\n",
    "        self.net = models.vgg19(pretrained=config.DOWNLOAD).features\n",
    "\n",
    "    def forward(self, x):\n",
    "        # 由于我们只需要指定层的输出，所以，我们同时构建特征提取函数只保留制定层的值\n",
    "        contents = []\n",
    "        styles = []\n",
    "        for i in range(len(self.net)):\n",
    "            x = self.net[i](x)\n",
    "            if i in self.style_layers:\n",
    "                styles.append(x)\n",
    "            if i in self.content_layers:\n",
    "                contents.append(x)\n",
    "        return contents, styles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# 构建损失函数 （最麻烦的地方了...）\n",
    "\n",
    "# 内容匹配只涉及一层，所以可以看成回归问题，直接使用均方误差。\n",
    "def content_loss(y_hat, y):\n",
    "    con_loss = 0\n",
    "    for y_i, y_j in zip(y_hat, y):\n",
    "        con_loss += torch.mean((y_i - y_j) ** 2)\n",
    "    return con_loss\n",
    "\n",
    "\n",
    "# 样式匹配则是通过拟合Gram矩阵\n",
    "def gram(y):\n",
    "    b, c, h, w = y.size()\n",
    "    y = y.view(c, h * w)\n",
    "    norm = h * w\n",
    "    return torch.mm(y, y.t()) / norm\n",
    "\n",
    "# 计算样式损失\n",
    "def style_loss(y_hat, y):\n",
    "    sty_loss = 0\n",
    "    for y_i, y_j in zip(y_hat, y):\n",
    "        sty_loss += torch.mean((gram(y_i) - gram(y_j)) ** 2)\n",
    "    return sty_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# 开始训练\n",
    "style = preprocess(config.style_image_path, transformer)\n",
    "\n",
    "content = preprocess(config.content_image_path, transformer)\n",
    "\n",
    "target = Variable(content.clone(), requires_grad=True)  # 这就是我们需要更新的是输入 x\n",
    "\n",
    "optimizer = torch.optim.Adam([target], lr=config.lr, betas=[0.5, 0.999])\n",
    "\n",
    "# 加载模型\n",
    "vgg = VGG_extract()\n",
    "\n",
    "if config.use_cuda:\n",
    "    vgg = vgg.cuda()\n",
    "\n",
    "for epcho in range(config.epoches):\n",
    "    # 获取特征\n",
    "    t_c_fea, t_s_fea = vgg(target)\n",
    "    c_fea, _ = vgg(Variable(content))\n",
    "    _, s_fea = vgg(Variable(style))\n",
    "\n",
    "    # 计算损失\n",
    "    c_loss = content_loss(t_c_fea, c_fea)\n",
    "    s_loss = style_loss(t_s_fea, s_fea)\n",
    "    loss = config.c_weight * c_loss + config.s_weight * s_loss\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    print(epcho)\n",
    "    if (epcho + 1) % config.show_epoch == 0:\n",
    "        print('Epcho [%d/%d], Content Loss: %.4f, Style Loss: %.4f'\n",
    "              % (epcho + 1, config.epoches, c_loss.data[0], s_loss.data[0]))\n",
    "        # 保存图片\n",
    "    if (epcho + 1) % config.sample_epoch == 0:\n",
    "        denorm = transforms.Normalize((-2.12, -2.04, -1.80), (4.37, 4.46, 4.44))\n",
    "        img = target.clone().cpu().squeeze()\n",
    "        img = denorm(img.data).clamp_(0, 1)\n",
    "        torchvision.utils.save_image(img, 'output-%d.png' % (epcho + 1))\n",
    "\n",
    "print(\"Done!\")"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda root]",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
